﻿#include "ColliderDefinitions.cginc"
#include "QueryDefinitions.cginc"
#include "ContactHandling.cginc"
#include "Transform.cginc"
#include "Simplex.cginc"
#include "Bounds.cginc"
#include "SolverParameters.cginc"
#include "Optimization.cginc"

#pragma kernel GenerateResults

// Per-particle data, indexed by particle index in the GenerateResults kernel.
StructuredBuffer<float4> positions;
StructuredBuffer<quaternion> orientations;
StructuredBuffer<float4> principalRadii;
// Particle indices of every simplex; a simplex occupies `simplexSize` consecutive entries starting at `simplexStart`.
StructuredBuffer<int> simplices;

// Per-query data, indexed by query index.
StructuredBuffer<transform> transforms;
StructuredBuffer<queryShape> shapes;

// Candidate pairs: x holds the simplex index, y holds the query index.
StructuredBuffer<uint2> contactPairs;
// Index of the first pair of each query shape type inside contactPairs.
StructuredBuffer<int> contactOffsetsPerType;

// Output results; slots are reserved through the buffer's hidden counter.
RWStructuredBuffer<queryResult> results;
// Read for the per-type pair count (entry 11 + 4*type); entries 0 and 3 are written by the kernel.
RWStructuredBuffer<uint> dispatchBuffer;

// Only element 0 is used: it is combined with each query transform to obtain the collider-to-solver transform.
StructuredBuffer<transform> worldToSolver;

// Results whose reserved slot is >= maxContacts are dropped instead of being written.
uint maxContacts;  

// Distance function for a box-shaped query.
// Projects a point onto the surface of the box described by `s`, expressed in solver space
// through `colliderToSolver`.
struct Box : IDistanceFunction
{
    queryShape s;
    transform colliderToSolver;
    
    // Finds the point on the box surface closest to `pos`.
    //   pos            - point to project, in solver space.
    //   radii          - unused by this shape.
    //   orientation    - unused by this shape.
    //   projectedPoint - receives the surface position (pushed outwards by s.contactOffset),
    //                    the surface normal and a barycentric coordinate of (1,0,0,0).
    void Evaluate(in float4 pos, in float4 radii, in quaternion orientation, inout SurfacePoint projectedPoint)
    {
        // Scale is applied manually so the rest of the math can use the unscaled transform.
        float4 center = s.center * colliderToSolver.scale;
        float4 size = s.size * colliderToSolver.scale * 0.5f; // half extents

        // point expressed in the box's local frame, relative to the box center:
        float4 pnt = colliderToSolver.InverseTransformPointUnscaled(pos) - center;

        // distance to the closest face along each axis (positive when inside the slab of that axis):
        float4 distances = size - abs(pnt);

        if (distances.x >= 0 && distances.y >= 0 && distances.z >= 0)
        {
            // The point is inside the box: push it out through the nearest face.
            projectedPoint.normal = float4(0,0,0,0);
            projectedPoint.pos = pnt;

            // Pick the axis with the smallest distance. Ties resolve to the lowest axis index
            // (x, then y, then z); the `<=` against z makes a y/z tie choose y instead of
            // wrongly falling through to the (farther) x face.
            // The face direction is computed with a comparison instead of sign(), because
            // sign(0) == 0 would leave a zero normal for points lying exactly on a center plane.
            if (distances.y < distances.x && distances.y <= distances.z)
            {
                projectedPoint.normal[1] = pnt[1] < 0 ? -1.0f : 1.0f;
                projectedPoint.pos[1] = size[1] * projectedPoint.normal[1];
            }
            else if (distances.z < distances.x && distances.z < distances.y)
            {
                projectedPoint.normal[2] = pnt[2] < 0 ? -1.0f : 1.0f;
                projectedPoint.pos[2] = size[2] * projectedPoint.normal[2];
            }
            else
            {
                projectedPoint.normal[0] = pnt[0] < 0 ? -1.0f : 1.0f;
                projectedPoint.pos[0] = size[0] * projectedPoint.normal[0];
            }
        }
        else
        {
            // The point is outside the box: clamp it to the box and point the normal
            // from the clamped position towards the original point.
            projectedPoint.pos = clamp(pnt, -size, size);
            projectedPoint.normal = normalizesafe(pnt - projectedPoint.pos);
        }

        // Back to solver space, offsetting the surface along the normal by the contact offset.
        projectedPoint.pos = colliderToSolver.TransformPointUnscaled(projectedPoint.pos + center + projectedPoint.normal * s.contactOffset);
        projectedPoint.normal = colliderToSolver.TransformDirection(projectedPoint.normal);
        projectedPoint.bary = float4(1,0,0,0);
    }
};

// One thread per (simplex, box query) pair: finds the point of the simplex closest to the box,
// and appends a query result when it lies within the query's maximum distance.
[numthreads(128, 1, 1)]
void GenerateResults (uint3 id : SV_DispatchThreadID) 
{
    uint pairIndex = id.x;

    // The dispatch buffer stores the pair count of each shape type at entry 11 + 4*type;
    // threads beyond the amount of box pairs have nothing to process.
    if (pairIndex >= dispatchBuffer[11 + 4*BOX_QUERY])
        return;

    // Box pairs start at this type's offset inside the pair list.
    int boxPairsStart = contactOffsetsPerType[BOX_QUERY];
    uint2 pair = contactPairs[boxPairsStart + pairIndex];
    int simplexIndex = pair.x;
    int queryIndex = pair.y;

    // Set up the distance function for this query, expressed in solver space.
    Box boxShape;
    boxShape.s = shapes[queryIndex];
    boxShape.colliderToSolver = worldToSolver[0].Multiply(transforms[queryIndex]);

    int simplexSize;
    int simplexStart = GetSimplexStartAndSize(simplexIndex, simplexSize);

    // Start the optimization at the simplex barycenter.
    float4 simplexBary = BarycenterForSimplexOfSize(simplexSize);
    float4 simplexPoint;

    SurfacePoint surfacePoint = Optimize(boxShape, positions, orientations, principalRadii,
                                         simplices, simplexStart, simplexSize, simplexBary, simplexPoint,
                                         surfaceCollisionIterations, surfaceCollisionTolerance);

    // Barycentric blend of the simplex particles' positions and of their radii along the normal.
    float4 blendedPosition = FLOAT4_ZERO;
    float blendedRadius = 0;
    for (int k = 0; k < simplexSize; ++k)
    {
        int particleIndex = simplices[simplexStart + k];
        float weight = simplexBary[k];

        blendedPosition += positions[particleIndex] * weight;
        blendedRadius += EllipsoidRadius(surfacePoint.normal, orientations[particleIndex], principalRadii[particleIndex].xyz) * weight;
    }

    queryResult c = (queryResult)0;
    c.simplexIndex = simplexIndex;
    c.queryIndex = queryIndex;
    c.simplexBary = simplexBary;
    c.queryPoint = surfacePoint.pos;
    c.normal = surfacePoint.normal;
    c.dist = dot(blendedPosition - surfacePoint.pos, surfacePoint.normal) - blendedRadius;

    // Only results within the query's maximum distance are reported.
    if (!(c.dist <= boxShape.s.maxDistance))
        return;

    // Reserve a slot; results that do not fit in the buffer are discarded.
    uint slot = results.IncrementCounter();
    if (slot >= maxContacts)
        return;

    results[slot] = c;

    // Keep the dispatch buffer up to date: entry 0 receives a group count for 128-wide groups
    // (presumably consumed by a later indirect dispatch), entry 3 the amount of results written.
    InterlockedMax(dispatchBuffer[0], (slot + 1) / 128 + 1);
    InterlockedMax(dispatchBuffer[3], slot + 1);
}